{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Medical Concept Word Embedding Retrieval (HuggingFace)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pj20/miniconda3/envs/txgnn_env/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import csv\n",
    "from pyhealth.medcode.pretrained_embeddings.lm_emb.huggingface_retriever import embedding_retrieve as embedding_retriever\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import json\n",
    "import retrying\n",
    "from transformers import AutoTokenizer, AutoModel, BioGptTokenizer, BioGptForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embedding_retrieve(model, tokenizer, phrase):\n",
    "    \"\"\"Return a mean-pooled embedding of `phrase` as a list of floats.\n",
    "\n",
    "    Args:\n",
    "        model: called as `model(**inputs)`; its result must expose a\n",
    "            `last_hidden_state` tensor.\n",
    "        tokenizer: called as `tokenizer(phrase, return_tensors='pt')`.\n",
    "        phrase: text to embed.\n",
    "\n",
    "    Returns:\n",
    "        `last_hidden_state` averaged over dim 1 (assumed to be the token\n",
    "        axis -- TODO confirm), first batch row only, as a Python list.\n",
    "    \"\"\"\n",
    "    # Imported here because the notebook's import cell does not bring in torch by name.\n",
    "    import torch\n",
    "\n",
    "    # Encode the phrase into the keyword arguments the model expects.\n",
    "    inputs = tokenizer(phrase, return_tensors='pt')\n",
    "\n",
    "    # This is pure inference: with autograd disabled no graph is recorded for\n",
    "    # the forward pass, which matters when embedding thousands of terms.\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "        embedding = outputs.last_hidden_state.mean(dim=1)\n",
    "\n",
    "    # Keep only the first batch row and return JSON-serialisable floats.\n",
    "    return embedding.detach().numpy().tolist()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at microsoft/biogpt were not used when initializing BioGptModel: ['output_projection.weight']\n",
      "- This IS expected if you are initializing BioGptModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BioGptModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "# Supported encoders, keyed by the short name that is also used in the\n",
    "# output paths. Keeping the hub ids in one table guarantees that the\n",
    "# tokenizer and the model are always loaded from the same checkpoint.\n",
    "CHECKPOINTS = {\n",
    "    \"bio_clinicalbert\": \"emilyalsentzer/Bio_ClinicalBERT\",\n",
    "    \"sapbert\": \"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\",\n",
    "    \"biogpt\": \"microsoft/biogpt\",\n",
    "}\n",
    "\n",
    "# Change only this line to regenerate the embeddings with another encoder.\n",
    "MODEL_NAME = \"biogpt\"\n",
    "TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINTS[MODEL_NAME])\n",
    "MODEL = AutoModel.from_pretrained(CHECKPOINTS[MODEL_NAME])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retry decorator kept for reference; it is currently not applied:\n",
    "# @retrying.retry(stop_max_attempt_number=5000)\n",
    "def retrieve_embedding(term):\n",
    "    \"\"\"Embed `term` using the notebook-level MODEL and TOKENIZER.\n",
    "\n",
    "    Thin wrapper around `embedding_retriever` so callers pass only the text.\n",
    "    \"\"\"\n",
    "    vector = embedding_retriever(MODEL, TOKENIZER, term)\n",
    "    return vector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Special Tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 21.25it/s]\n"
     ]
    }
   ],
   "source": [
    "special_tokens = [\"<pad>\", \"<unk>\"]\n",
    "\n",
    "# Map each special token string to its embedding.\n",
    "st_id2emb = {token: retrieve_embedding(term=token) for token in tqdm(special_tokens)}\n",
    "\n",
    "# Persist the token -> embedding mapping as JSON.\n",
    "with open(f\"../resource/embeddings/LM/{MODEL_NAME}/special_tokens/special_tokens.json\", \"w\") as f:\n",
    "    json.dump(st_id2emb, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CCSCM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 285/285 [00:12<00:00, 22.10it/s]\n"
     ]
    }
   ],
   "source": [
    "# Parse with csv.reader instead of str.split(','): a quoted name that itself\n",
    "# contains a comma is then kept whole rather than cut at the first comma.\n",
    "ccscm_id2name = {}\n",
    "with open('../resource/CCSCM.csv', 'r', newline='') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader, None)  # skip the first row (treated as the header)\n",
    "    for row in reader:\n",
    "        if not row:\n",
    "            continue  # ignore blank lines\n",
    "        # Column 0 is used as the code, column 1 as the name to embed.\n",
    "        ccscm_id2name[row[0].strip()] = row[1].strip().lower()\n",
    "\n",
    "# Embed every lower-cased name, keyed by its code.\n",
    "ccscm_id2emb = {}\n",
    "for code, name in tqdm(ccscm_id2name.items()):\n",
    "    ccscm_id2emb[code] = retrieve_embedding(term=name)\n",
    "\n",
    "with open(f\"../resource/embeddings/LM/{MODEL_NAME}/conditions/ccscm.json\", \"w\") as f:\n",
    "    json.dump(ccscm_id2emb, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CCSPROC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 231/231 [00:10<00:00, 22.86it/s]\n"
     ]
    }
   ],
   "source": [
    "# Parse with csv.reader instead of str.split(','): a quoted name that itself\n",
    "# contains a comma is then kept whole rather than cut at the first comma.\n",
    "ccsproc_id2name = {}\n",
    "with open('../resource/CCSPROC.csv', 'r', newline='') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader, None)  # skip the first row (treated as the header)\n",
    "    for row in reader:\n",
    "        if not row:\n",
    "            continue  # ignore blank lines\n",
    "        # Column 0 is used as the code, column 1 as the name to embed.\n",
    "        ccsproc_id2name[row[0].strip()] = row[1].strip().lower()\n",
    "\n",
    "# Embed every lower-cased name, keyed by its code.\n",
    "ccsproc_id2emb = {}\n",
    "for code, name in tqdm(ccsproc_id2name.items()):\n",
    "    ccsproc_id2emb[code] = retrieve_embedding(term=name)\n",
    "\n",
    "with open(f\"../resource/embeddings/LM/{MODEL_NAME}/procedures/ccsproc.json\", \"w\") as f:\n",
    "    json.dump(ccsproc_id2emb, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ATC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6440/6440 [04:01<00:00, 26.69it/s]\n"
     ]
    }
   ],
   "source": [
    "# Read the `code` and `name` columns of every row; rows are not filtered\n",
    "# by `level`, so all of them are embedded.\n",
    "atc_id2name = {}\n",
    "with open(\"../resource/ATC.csv\", newline='') as csvfile:\n",
    "    reader = csv.DictReader(csvfile)\n",
    "    for row in reader:\n",
    "        atc_id2name[row['code']] = row['name'].lower()\n",
    "\n",
    "# Embed every lower-cased name, keyed by its code.\n",
    "atc_id2emb = {}\n",
    "for code, name in tqdm(atc_id2name.items()):\n",
    "    atc_id2emb[code] = retrieve_embedding(term=name)\n",
    "\n",
    "with open(f\"../resource/embeddings/LM/{MODEL_NAME}/drugs/atc.json\", \"w\") as f:\n",
    "    json.dump(atc_id2emb, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "6920it [04:56, 23.31it/s]\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Load the metadata TSV; only its `name` column is used below.\n",
    "df = pd.read_csv(\"/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/detailed/word_embedding/BioGPT/ATC_Metadata.tsv\", sep='\\t')\n",
    "\n",
    "# Iterate the column directly instead of df.iterrows(): no per-row Series is\n",
    "# built, and tqdm can read the length, so the bar shows a total and an ETA.\n",
    "emb_list = [retrieve_embedding(term=name.lower()) for name in tqdm(df['name'])]\n",
    "\n",
    "# One embedding per output line, in the same order as the metadata rows.\n",
    "output_file_path = \"/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/detailed/word_embedding/BioGPT/ATC_embedding.tsv\"\n",
    "\n",
    "with open(output_file_path, \"w\", newline='\\n') as file:\n",
    "    # Values are stringified and tab-separated; each row ends with a newline.\n",
    "    file.writelines('\\t'.join(map(str, emb)) + '\\n' for emb in emb_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ICD9CM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 17736/17736 [11:34<00:00, 25.54it/s]\n"
     ]
    }
   ],
   "source": [
    "from pyhealth.medcode import ICD9CM\n",
    "\n",
    "# Parse with csv.reader instead of str.split(','): a quoted name that itself\n",
    "# contains a comma is then kept whole rather than cut at the first comma.\n",
    "icd9cm_id2name = {}\n",
    "with open('../resource/ICD9CM.csv', 'r', newline='') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader, None)  # skip the first row (treated as the header)\n",
    "    for row in reader:\n",
    "        if not row:\n",
    "            continue  # ignore blank lines\n",
    "        # Column 0 is used as the code, column 2 as the name to embed.\n",
    "        icd9cm_id2name[row[0].strip()] = row[2].strip().lower()\n",
    "\n",
    "# Output keys are the standardized code with the dot removed.\n",
    "icd9cm_id2emb = {}\n",
    "for code, name in tqdm(icd9cm_id2name.items()):\n",
    "    emb = retrieve_embedding(term=name)\n",
    "    icd9cm_id2emb[ICD9CM.standardize(code).replace('.', '')] = emb\n",
    "\n",
    "with open(f\"../resource/embeddings/LM/{MODEL_NAME}/conditions/icd9cm.json\", \"w\") as f:\n",
    "    json.dump(icd9cm_id2emb, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "17736it [13:30, 21.89it/s]\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Load the metadata TSV; only its `name` column is used below.\n",
    "df = pd.read_csv(\"/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/detailed/word_embedding/BioGPT/ICD9CM_Metadata.tsv\", sep='\\t')\n",
    "\n",
    "# Iterate the column directly instead of df.iterrows(): no per-row Series is\n",
    "# built, and tqdm can read the length, so the bar shows a total and an ETA.\n",
    "emb_list = [retrieve_embedding(term=name.lower()) for name in tqdm(df['name'])]\n",
    "\n",
    "# One embedding per output line, in the same order as the metadata rows.\n",
    "output_file_path = \"/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/detailed/word_embedding/BioGPT/ICD9CM_embedding.tsv\"\n",
    "\n",
    "with open(output_file_path, \"w\", newline='\\n') as file:\n",
    "    # Values are stringified and tab-separated; each row ends with a newline.\n",
    "    file.writelines('\\t'.join(map(str, emb)) + '\\n' for emb in emb_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ICD9PROC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4671/4671 [03:12<00:00, 24.30it/s]\n"
     ]
    }
   ],
   "source": [
    "from pyhealth.medcode import ICD9PROC\n",
    "\n",
    "# Parse with csv.reader instead of str.split(','): a quoted name that itself\n",
    "# contains a comma is then kept whole rather than cut at the first comma.\n",
    "icd9proc_id2name = {}\n",
    "with open('../resource/ICD9PROC.csv', 'r', newline='') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader, None)  # skip the first row (treated as the header)\n",
    "    for row in reader:\n",
    "        if not row:\n",
    "            continue  # ignore blank lines\n",
    "        # Column 0 is used as the code, column 2 as the name to embed.\n",
    "        icd9proc_id2name[row[0].strip()] = row[2].strip().lower()\n",
    "\n",
    "# Output keys are the standardized code with the dot removed.\n",
    "icd9proc_id2emb = {}\n",
    "for code, name in tqdm(icd9proc_id2name.items()):\n",
    "    emb = retrieve_embedding(term=name)\n",
    "    icd9proc_id2emb[ICD9PROC.standardize(code).replace('.', '')] = emb\n",
    "\n",
    "with open(f\"../resource/embeddings/LM/{MODEL_NAME}/procedures/icd9proc.json\", \"w\") as f:\n",
    "    json.dump(icd9proc_id2emb, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4670it [03:40, 21.22it/s]\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Load the metadata TSV; only its `name` column is used below.\n",
    "df = pd.read_csv(\"/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/detailed/word_embedding/BioGPT/ICD9PROC_Metadata.tsv\", sep='\\t')\n",
    "\n",
    "# Iterate the column directly instead of df.iterrows(): no per-row Series is\n",
    "# built, and tqdm can read the length, so the bar shows a total and an ETA.\n",
    "emb_list = [retrieve_embedding(term=name.lower()) for name in tqdm(df['name'])]\n",
    "\n",
    "# One embedding per output line, in the same order as the metadata rows.\n",
    "output_file_path = \"/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/detailed/word_embedding/BioGPT/ICD9PROC_embedding.tsv\"\n",
    "\n",
    "with open(output_file_path, \"w\", newline='\\n') as file:\n",
    "    # Values are stringified and tab-separated; each row ends with a newline.\n",
    "    file.writelines('\\t'.join(map(str, emb)) + '\\n' for emb in emb_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4671"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of entries in the ICD9PROC embedding table.\n",
    "len(icd9proc_id2emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(f\"../resource/embeddings/LM/{MODEL_NAME}/procedures/icd9proc.json\", \"r\") as f:\n",
    "    icd9proc_id2emb = json.load(f)\n",
    "\n",
    "# Re-key every entry with the dots removed from its code.\n",
    "icd9proc_id2emb_new = {key.replace('.', ''): value for key, value in icd9proc_id2emb.items()}\n",
    "\n",
    "# Alias two codes to existing embeddings (presumably codes that are absent\n",
    "# from the source table -- confirm). These assignments do not depend on the\n",
    "# loop variable, so they are applied once after re-keying.\n",
    "icd9proc_id2emb_new['3605'] = icd9proc_id2emb['0066']\n",
    "icd9proc_id2emb_new['3602'] = icd9proc_id2emb['36']\n",
    "\n",
    "with open(f\"../resource/embeddings/LM/{MODEL_NAME}/procedures/icd9proc.json\", \"w\") as f:\n",
    "    json.dump(icd9proc_id2emb_new, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the stored ICD9PROC embedding table back from disk.\n",
    "with open(f\"../resource/embeddings/LM/{MODEL_NAME}/procedures/icd9proc.json\", \"r\") as f:\n",
    "    icd9proc_id2emb = json.loads(f.read())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check whether code 3602 has an entry in the loaded table.\n",
    "'3602' in icd9proc_id2emb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.16 ('txgnn_env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "79cb95e61c4f960f4e102f21c45668d32cb5c494b237694c15d64b50342e6e99"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
